#!/usr/bin/env python3

import argparse
import os
import random
import subprocess
import sys
from shutil import copyfile

import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import transforms

sys.path.insert(1, os.path.join(sys.path[0], ('..')))

from colorization.colorization_model import ColorizationModel
from colorization.data.image_directory import ImageDirectory
from colorization.data.transforms import RGBOrGrayToL, ToNumpy
from colorization.modules.colorization_network import ColorizationNetwork
from colorization.util.argparse import nice_help_formatter
from colorization.util.image import images_in_directory, imread, imsave, \
                                    lab_to_rgb, resize, rgb_to_gray, \
                                    rgb_to_lab, to_gray, to_rgb, torch_to_numpy


# Hand-written usage text passed to argparse (replaces the auto-generated
# one).  Continuation lines are indented by len("usage: convert_images ") so
# they line up under the first option once argparse adds its "usage: " prefix.
USAGE = "\n".join([
    "convert_images [-h|--help]",
    *(" " * 22 + option for option in (
        "--input-dir INPUT_DIR",
        "--output-dir OUTPUT_DIR",
        "[--force]",
        "[--resize-height RESIZE_HEIGHT]",
        "[--resize-width RESIZE_WIDTH]",
        "[--color-source-dir COLOR_SOURCE_DIR]",
        "[--method METHOD]",
        "[--base-network NETWORK_ID]",
        "[--annealed-mean-T ANNEALED_MEAN_T]",
        "[--model-checkpoint MODEL_CHECKPOINT]",
        "[--gpu]",
        "[--verbose]",
        "{resize,no_color,random_color,predict_color}",
    )),
])


def _random_color(img, color_source_images):
    """Recolor `img` with the ab channels of a randomly chosen color image.

    A path is drawn uniformly at random from `color_source_images` until one
    is found whose image has exactly three channels; its ab channels (as
    returned by `rgb_to_lab`) are stacked onto the first channel of `img`
    and the result is converted back with `lab_to_rgb`.

    Args:
        img: array indexed as (row, column, channel); only the first channel
            is used, as the lightness channel.
            NOTE(review): the value range expected by `lab_to_rgb` is not
            visible here -- confirm against the caller.
        color_source_images: sequence of image paths readable by `imread`.

    Returns:
        Whatever `lab_to_rgb` returns for the stacked channels.

    Raises:
        ValueError: if `color_source_images` is empty or contains no
            three-channel image (the original looped forever or raised a
            bare IndexError in these cases).
    """
    # Paths still eligible for drawing.  The caller's sequence is used
    # directly until the first rejection so the common case copies nothing.
    remaining = None

    while True:
        pool = color_source_images if remaining is None else remaining

        if len(pool) == 0:
            raise ValueError("no RGB image found among color source images")

        path = random.choice(pool)
        color_source = imread(path)

        if len(color_source.shape) == 3 and color_source.shape[2] == 3:
            break

        # Drop the rejected path so the loop is guaranteed to terminate;
        # the pick stays uniform over the usable images.
        if remaining is None:
            remaining = list(color_source_images)

        remaining.remove(path)

    lightness = img[:, :, :1]
    ab = rgb_to_lab(color_source)[:, :, 1:]

    # np.dstack needs matching heights/widths; resize the chroma channels to
    # the target image the same way _predict_color_zhang does for predicted
    # ab channels.  Skipped when the sizes already agree so previously
    # working inputs produce unchanged output.
    if ab.shape[:2] != lightness.shape[:2]:
        ab = resize(ab, lightness.shape[:2])

    return lab_to_rgb(np.dstack((lightness, ab)))


def _predict_color_zhang(args):
    """Colorize every image in `args.input_dir` with a `ColorizationModel`.

    The input directory is read twice in lockstep: once at original size
    (lightness channel plus file name) and once resized to the network input
    size (`args.resize_height` x `args.resize_width`, each falling back to
    224).  The model's prediction for the resized input is resized back to
    the original image size, stacked onto the original lightness channel,
    converted with `lab_to_rgb` and saved under `args.output_dir` using the
    input file name.

    Args:
        args: parsed command line arguments; reads `input_dir`,
            `output_dir`, `resize_height`, `resize_width`, `gpu`,
            `base_network`, `annealed_mean_T`, `model_checkpoint` and
            `verbose`.
    """
    # original-size lightness images together with their file names
    to_lightness = transforms.Compose([
        RGBOrGrayToL(),
        transforms.ToTensor()
    ])

    originals = ImageDirectory(args.input_dir,
                               return_labels=False,
                               return_filenames=True,
                               transform=to_lightness)

    originals_loader = DataLoader(originals)

    # the same images, resized to the network input size
    net_height = args.resize_height or 224
    net_width = args.resize_width or 224

    to_resized_lightness = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((net_height, net_width)),
        ToNumpy(),
        RGBOrGrayToL(),
        transforms.ToTensor()
    ])

    net_inputs = ImageDirectory(args.input_dir,
                                return_labels=False,
                                transform=to_resized_lightness)

    net_inputs_loader = DataLoader(net_inputs, pin_memory=args.gpu)

    # build the network and restore the trained weights
    if args.gpu:
        device = 'cuda'
    else:
        device = 'cpu'

    network = ColorizationNetwork(base_network=args.base_network,
                                  annealed_mean_T=args.annealed_mean_T,
                                  device=device)

    model = ColorizationModel(network)
    model.load_checkpoint(args.model_checkpoint)

    # Both loaders use the default (unshuffled, batch size 1) settings, so
    # zipping them pairs each original image with its resized counterpart.
    paired_batches = zip(originals_loader, net_inputs_loader)

    with torch.no_grad():
        for (lightness_batch, names), net_input in paired_batches:
            name = names[0]

            if args.verbose:
                print("processing '{}'".format(name))

            prediction = model.predict(net_input.to(device))

            lightness = torch_to_numpy(lightness_batch)

            # bring the predicted channels back to the original image size
            ab = resize(torch_to_numpy(prediction), lightness.shape[:2])

            colorized = lab_to_rgb(np.dstack((lightness, ab)))

            imsave(os.path.join(args.output_dir, name), colorized)


# Script entry point: parse the command line, validate mode-specific options,
# prepare the output directory, convert every image of the input directory
# according to the selected mode and finally copy the label file (if any).
if __name__ == '__main__':
    # parse command line arguments
    parser = argparse.ArgumentParser(formatter_class=nice_help_formatter(),
                                     usage=USAGE)

    parser.add_argument('mode',
        choices=['resize', 'no_color', 'random_color', 'predict_color'])

    parser.add_argument('--input-dir',
                        required=True,
                        help="input image directory")

    parser.add_argument('--output-dir',
                        required=True,
                        help="output image directory")

    parser.add_argument('--force',
                        action='store_true',
                        help="overwrite output directory")

    parser.add_argument('--resize-height',
                        type=int,
                        help=str("if mode is 'resize', resize images to this "
                                 "height, if mode is 'predict_color', resize "
                                 "network inputs to this height (output images "
                                 "will still have the same dimensions as input "
                                 "images), in the latter case this is 224 by "
                                 "default"))

    parser.add_argument('--resize-width',
                        type=int,
                        help="see --resize-height")

    parser.add_argument('--color-source-dir',
                        help=str("only meaningful in conjunction with "
                                 "'random_color', random colors are extracted "
                                 "from randomly chosen images in this "
                                 "directory "))

    parser.add_argument('--method',
                        choices=['zhang', 'larsson', 'iizuka'],
                        default='zhang',
                        help=str("only meaningful in conjunction with "
                                 "'predict_color'"))

    parser.add_argument('--base-network',
                        choices=['vgg', 'deeplab'],
                        default='vgg',
                        help=str("only meaningful in conjunction with "
                                 "'predict_color', network architecture "
                                 "id of prediction model (default: %(default)s)"))

    parser.add_argument('--annealed-mean-T',
                        type=float,
                        default=0.38,
                        help=str("only meaningful in conjunction with "
                                 "'predict_color', annealed mean decoding "
                                 "temperature parameter (default: %(default)s)"))

    parser.add_argument('--model-checkpoint',
                        help=str("only meaningful in conjunction with "
                                 "'predict_color', the color prediction "
                                 "model is loaded from this checkpoint"))

    parser.add_argument('--gpu',
                        action='store_true',
                        help=str("only meaningful in conjunction with "
                                 "'predict_color', run prediction on GPU"))

    parser.add_argument('--verbose',
                        action='store_true',
                        help="display progress")

    args = parser.parse_args()

    # validate arguments
    if args.mode == 'resize':
        if args.resize_height is None or args.resize_width is None:
            err = "dimensions must be specified if mode is 'resize'"
            raise ValueError(err)
    elif args.mode == 'random_color' and args.color_source_dir is None:
        err = "--color-source-dir must be specified if mode is 'random_color'"
        raise ValueError(err)
    elif (args.mode == 'predict_color' and
          args.method != 'larsson' and
          args.model_checkpoint is None):
        # 'larsson' uses autocolorize's default classifier, so it is the
        # only prediction method that does not need a checkpoint
        err = "--model-checkpoint must be specified if mode is 'predict_color'"
        raise ValueError(err)

    # create output directory
    if os.path.exists(args.output_dir):
        # an existing, non-empty directory is only reused with --force
        if not args.force and os.listdir(args.output_dir):
            raise ValueError("'{}' already exists".format(args.output_dir))
    else:
        # makedirs (rather than mkdir) so nested output paths work as well
        os.makedirs(args.output_dir)

    # convert images
    if args.mode == 'predict_color' and args.method == 'zhang':
        _predict_color_zhang(args)
    else:
        input_dir = os.path.abspath(args.input_dir)
        output_dir = os.path.abspath(args.output_dir)

        # every branch below defines transform(in_path, out_path), which
        # reads one input image and writes the converted image
        if args.mode == 'predict_color' and args.method == 'larsson':
            import autocolorize

            classifier = autocolorize.load_default_classifier()

            def transform(in_path, out_path):
                img = to_gray(imread(in_path))
                img = img.astype(np.float64) / 255
                img = autocolorize.colorize(img, classifier=classifier)
                imsave(out_path, img)

        elif args.mode == 'predict_color' and args.method == 'iizuka':
            # colorize.lua is run from the directory containing the
            # checkpoint.  abspath first, so a bare checkpoint file name
            # resolves to the current directory instead of ''.
            iizuka_dir = os.path.dirname(
                os.path.abspath(args.model_checkpoint))

            def transform(in_path, out_path):
                # Use `cwd` instead of os.chdir so the relative
                # --input-dir/--output-dir used for the label copy below
                # stay valid, and an argument list instead of a shell string
                # so paths with spaces or shell metacharacters are passed
                # through verbatim.
                subprocess.run(['th', 'colorize.lua', in_path, out_path],
                               cwd=iizuka_dir,
                               check=True)

        elif args.mode == 'resize':
            h, w = args.resize_height, args.resize_width

            def transform(in_path, out_path):
                img = imread(in_path)
                imsave(out_path, resize(img, (h, w)))

        elif args.mode == 'no_color':
            def transform(in_path, out_path):
                img = to_gray(imread(in_path))
                imsave(out_path, img)

        elif args.mode == 'random_color':
            color_source_images = images_in_directory(args.color_source_dir)

            def transform(in_path, out_path):
                img = to_gray(imread(in_path))
                imsave(out_path, _random_color(img, color_source_images))

        for filename in images_in_directory(input_dir, exclude_root=True):
            if args.verbose:
                print("processing '{}'".format(filename))

            in_path = os.path.join(input_dir, filename)
            out_path = os.path.join(output_dir, filename)

            transform(in_path, out_path)

    # copy labels
    label_file = os.path.join(args.input_dir, ImageDirectory.LABEL_FILENAME)

    if os.path.exists(label_file):
        out_label_file = os.path.join(
            args.output_dir, ImageDirectory.LABEL_FILENAME)

        copyfile(label_file, out_label_file)
